/*
 * AccumulativeChecksumUtil.h
 *
 * This source file is part of the FoundationDB open source project
 *
 * Copyright 2013-2024 Apple Inc. and the FoundationDB project authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ACCUMULATIVECHECKSUMUTIL_H
#define ACCUMULATIVECHECKSUMUTIL_H
#pragma once

#include "fdbclient/AccumulativeChecksum.h"
#include "fdbclient/CommitTransaction.h"
#include "fdbclient/SystemData.h"

// Reserved ACS index meaning "no ACS index assigned".
static constexpr uint16_t invalidAccumulativeChecksumIndex = 0;
// ACS index reserved for mutations produced by the resolver.
static constexpr uint16_t resolverAccumulativeChecksumIndex = 2;
// Starting ACS value before any checksum has been folded in.
static constexpr uint32_t initialAccumulativeChecksum = 0;

// Maps a commit proxy index to the ACS index used for its mutations.
// The encoding leaves flexibility for ACS indices generated by other
// components: an ACS index ending with 1 indicates the mutation comes from a
// commit proxy (proxy 0 -> 1, proxy 1 -> 11, proxy 2 -> 21, ...).
// The arithmetic is done in int (as integer promotion would do anyway) and
// narrowed back to uint16_t on return.
inline uint16_t getCommitProxyAccumulativeChecksumIndex(uint16_t commitProxyIndex) {
	const int encodedIndex = static_cast<int>(commitProxyIndex) * 10 + 1;
	return static_cast<uint16_t>(encodedIndex);
}

// Folds one checksum into a running accumulative checksum (ACS) value.
// The combination is bitwise XOR, so it is commutative and associative, and
// folding the same checksum in twice cancels it out.
inline uint32_t calculateAccumulativeChecksum(uint32_t currentAccumulativeChecksum, uint32_t newChecksum) {
	const uint32_t combined = newChecksum ^ currentAccumulativeChecksum;
	return combined;
}

// Returns true when ACS values are tracked for `tag`: only tags with a
// non-negative locality (storage server tags) are supported.
// TODO: add log router tag, i.e., -2, so that new backup (backup workers) can be supported.
inline bool tagSupportAccumulativeChecksum(Tag tag) {
	const bool hasNegativeLocality = tag.locality < 0;
	return !hasNegativeLocality;
}

// Define how to aggregate ACS values of a vector of mutations from a starting ACS
inline uint32_t aggregateAcs(uint32_t startAcs, Standalone<VectorRef<std::pair<Version, MutationRef>>> mutations) {
	uint32_t newAcs = startAcs;
	for (const auto& [version, mutation] : mutations) {
		ASSERT(mutation.checksum.present());
		newAcs = calculateAccumulativeChecksum(newAcs, mutation.checksum.get());
	}
	return newAcs;
}

// A builder to generate accumulative checksum (ACS) values and keep track of
// the ACS value for each tag.
// Currently, accumulative checksum only supports mutations generated by the
// commit proxy with encryption disabled, and only storage server tags
// (aka. locality >= 0).
class AccumulativeChecksumBuilder {
public:
	AccumulativeChecksumBuilder(uint16_t acsIndex) : acsIndex(acsIndex), currentVersion(0) {}

	// Called when commit proxy applies a new tag assignment mutation.
	// At this time, this method erases the corresponding ACS value of the tag.
	void newTag(Tag tag, UID ssid, Version commitVersion);

	// Called when commit proxy assigns tags to a mutation (e.g. mutation, private mutation).
	// Updates the ACS value for the input tag assigned to the mutation.
	void addMutation(const MutationRef& mutation, Tag tag, LogEpoch epoch, UID commitProxyId, Version commitVersion);

	// Returns a read-only reference to the per-tag ACS map.
	// Const-qualified so it can also be called through a const builder. The
	// returned reference aliases the builder's member, so it reflects later
	// updates and must not outlive the builder.
	const std::unordered_map<Tag, AccumulativeChecksumState>& getAcsTable() const { return acsTable; }

private:
	uint16_t acsIndex; // Essentially, this is the ID of commit proxy
	Version currentVersion;
	std::unordered_map<Tag, AccumulativeChecksumState> acsTable;

	// Update ACS state of the input tag to the input values
	uint32_t updateTable(Tag tag, uint32_t checksum, Version version, LogEpoch epoch);
};

// Modifies the input `mutation` in place by populating its checksum and
// setting its ACS index, then adds the mutation and the corresponding
// inputTag to `acsBuilder`.
// NOTE(review): the exact use of acsIndex / epoch / commitVersion /
// commitProxyId is not visible in this header; confirm against the definition.
void updateMutationWithAcsAndAddMutationToAcsBuilder(std::shared_ptr<AccumulativeChecksumBuilder> acsBuilder,
                                                     MutationRef& mutation,
                                                     Tag inputTag,
                                                     uint16_t acsIndex,
                                                     LogEpoch epoch,
                                                     Version commitVersion,
                                                     UID commitProxyId);

// Overload taking a vector of tags.
// NOTE(review): presumably applies the single-tag behavior above for every
// tag in inputTags; verify against the definition.
void updateMutationWithAcsAndAddMutationToAcsBuilder(std::shared_ptr<AccumulativeChecksumBuilder> acsBuilder,
                                                     MutationRef& mutation,
                                                     const std::vector<Tag>& inputTags,
                                                     uint16_t acsIndex,
                                                     LogEpoch epoch,
                                                     Version commitVersion,
                                                     UID commitProxyId);

// Overload taking a set of tags.
// NOTE(review): presumably applies the single-tag behavior above for every
// tag in inputTags; verify against the definition.
void updateMutationWithAcsAndAddMutationToAcsBuilder(std::shared_ptr<AccumulativeChecksumBuilder> acsBuilder,
                                                     MutationRef& mutation,
                                                     const std::set<Tag>& inputTags,
                                                     uint16_t acsIndex,
                                                     LogEpoch epoch,
                                                     Version commitVersion,
                                                     UID commitProxyId);

// A validator to check if the accumulative checksum (ACS) is correct for
// each version that has mutations.
class AccumulativeChecksumValidator {
public:
	// Explicitly defaulted: every data member below either has a default
	// member initializer or is default-constructible, so no custom logic is needed.
	AccumulativeChecksumValidator() = default;

	// Called when SS pulls a non-ACS mutation.
	// Adds the mutation to the mutation buffer.
	void addMutation(const MutationRef& mutation, UID ssid, Tag tag, Version ssVersion, Version mutationVersion);

	// Called when SS receives an ACS mutation.
	// Consumes the current mutation buffer to generate an ACS value, validates
	// it against the ACS carried by the ACS mutation, and reports an error if
	// the two ACS values mismatch. Then updates the acs table using the ACS mutation.
	// Returns the acs state to persist (a mutation is issued to persist the ACS
	// state after this method is called).
	Optional<AccumulativeChecksumState> processAccumulativeChecksum(const AccumulativeChecksumState& acsMutationState,
	                                                                UID ssid,
	                                                                Tag tag,
	                                                                Version ssVersion);

	// Called when SS restores from persisted private data.
	// Overwrites the existing acsState with the input acsState for the same acsIndex.
	void restore(const AccumulativeChecksumState& acsState, UID ssid, Tag tag, Version ssVersion);

	// Called when SS has applied the mutations pulled in a round.
	// At this point no mutation is expected in the buffer, because the buffer
	// is consumed by the ACS mutation at the end of each version that has
	// mutations. However, if ACS mutations are missing, the buffer is not
	// consumed. In that case the existing buffer is force-cleared to keep
	// memory usage bounded; the ACS value mismatch will be reported when the
	// next ACS mutation arrives.
	void clearCache(UID ssid, Tag tag, Version ssVersion);

	// Called when StorageMetrics is generated.
	// Returns the existing counter and clears it, so that StorageMetrics
	// reports the count accumulated since the previous StorageMetrics report.
	// NOTE(review): the other getAndClear* methods presumably follow the same
	// return-and-reset contract for their counters; confirm in the definitions.
	uint64_t getAndClearCheckedMutations();

	uint64_t getAndClearCheckedVersions();

	uint64_t getAndClearTotalMutations();

	uint64_t getAndClearTotalAcsMutations();

	uint64_t getAndClearTotalAddedMutations();

	// Called when SS pulls a mutation
	void incrementTotalMutations() { totalMutations++; }

	// Called when SS receives an ACS mutation
	void incrementTotalAcsMutations() { totalAcsMutations++; }

private:
	std::unordered_map<uint16_t, AccumulativeChecksumState> acsTable;

	// Every mutation is added to mutationBuffer first. The buffered mutations
	// are consumed to generate an ACS value when SS receives the next ACS mutation.
	Standalone<VectorRef<std::pair<Version, MutationRef>>> mutationBuffer;
	uint64_t checkedMutations = 0; // the number of mutations checked by ACS
	uint64_t checkedVersions = 0; // the number of versions checked by ACS
	uint64_t totalMutations = 0; // the number of mutations received by SS
	uint64_t totalAcsMutations = 0; // the number of ACS mutations received by SS
	uint64_t totalAddedMutations = 0; // the number of mutations added to mutationBuffer
};

#endif
